#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
SQL备份脚本生成器
功能：
1. 读取SQL DDL文件
2. 按分号分割SQL语句
3. 去掉每个语句的第一行
4. 给表名添加_20250625_bak后缀
5. 生成备份和还原语句
"""

import re
import os
from typing import List, Tuple


def read_sql_file(file_path: str) -> str:
    """Return the text of a UTF-8 encoded SQL file.

    This is a best-effort reader: it never raises. When the file is missing,
    or any other error occurs while opening or reading it, a message is
    printed and an empty string is returned instead.

    Args:
        file_path: Path of the SQL file to read.

    Returns:
        The file content, or "" if the file could not be read.
    """
    text = ""
    try:
        with open(file_path, mode='r', encoding='utf-8') as sql_file:
            text = sql_file.read()
    except FileNotFoundError:
        print(f"错误：文件 {file_path} 不存在")
    except Exception as e:
        # Covers decoding errors, permission problems, etc.; `text` is still "".
        print(f"读取文件时出错：{e}")
    return text


def split_sql_statements(content: str) -> List[str]:
    """Split SQL text into statements at top-level semicolons.

    A semicolon only ends a statement when it is outside of:
      * single-quoted or double-quoted strings (backslash escapes honoured),
      * backtick-quoted identifiers,
      * ``--`` line comments and ``/* ... */`` block comments.

    This keeps text such as ``COMMENT "a;b"`` inside a single statement,
    which a plain ``str.split(';')`` would cut in half.

    Args:
        content: Raw SQL text, possibly containing many statements.

    Returns:
        The statements in order, each stripped of surrounding whitespace.
        Every statement that was terminated by a semicolon keeps that
        semicolon; trailing text with no terminator is returned without one.
        Whitespace-only fragments are dropped. An unterminated quote or block
        comment extends to the end of the input.
    """
    statements = []
    current = []      # pieces of the statement currently being collected
    quote = None      # the quote character we are inside of, if any
    i = 0
    n = len(content)

    while i < n:
        ch = content[i]

        if quote is not None:
            current.append(ch)
            # Inside '...' or "..." a backslash escapes the next character,
            # so an escaped quote must not close the literal.
            if ch == '\\' and quote != '`' and i + 1 < n:
                current.append(content[i + 1])
                i += 2
                continue
            if ch == quote:
                # A doubled quote ('' or ``) closes and immediately reopens,
                # which yields the right result without special handling.
                quote = None
            i += 1
        elif ch in ("'", '"', '`'):
            quote = ch
            current.append(ch)
            i += 1
        elif content.startswith('--', i):
            # Line comment: copy through to (not including) the newline.
            end = content.find('\n', i)
            end = n if end == -1 else end
            current.append(content[i:end])
            i = end
        elif content.startswith('/*', i):
            # Block comment: copy through the closing marker.
            end = content.find('*/', i + 2)
            end = n if end == -1 else end + 2
            current.append(content[i:end])
            i = end
        elif ch == ';':
            statement = ''.join(current).strip()
            if statement:  # skip empty fragments such as ";;"
                statements.append(statement + ';')
            current = []
            i += 1
        else:
            current.append(ch)
            i += 1

    # Trailing text after the last semicolon is kept, without a semicolon.
    tail = ''.join(current).strip()
    if tail:
        statements.append(tail)

    return statements


def remove_first_line(sql_statement: str) -> str:
    """Drop everything up to and including the first newline.

    Args:
        sql_statement: A SQL statement, possibly spanning several lines.

    Returns:
        The text after the first newline. A statement that contains no
        newline is returned unchanged.
    """
    _first, newline, remainder = sql_statement.partition('\n')
    if not newline:
        # Single-line statement: there is no "first line" to strip.
        return sql_statement
    return remainder


def extract_table_name(create_table_sql: str) -> str:
    """Find the table name of the first CREATE TABLE clause in the text.

    The match is case-insensitive, allows an optional ``IF NOT EXISTS``, and
    accepts the name with or without surrounding backticks. Only names made
    of ASCII letters, digits and underscores (not starting with a digit) are
    recognised.

    Args:
        create_table_sql: SQL text to search.

    Returns:
        The bare table name (no backticks), or None when no CREATE TABLE
        clause is found.
    """
    found = re.search(
        r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?`?([a-zA-Z_][a-zA-Z0-9_]*)`?',
        create_table_sql,
        flags=re.IGNORECASE,
    )
    return found.group(1) if found else None


def add_backup_suffix(sql_statement: str, table_name: str,
                      suffix: str = "_20250625_bak") -> Tuple[str, str]:
    """Rewrite a CREATE TABLE statement so it creates the backup table.

    Every ``CREATE TABLE [IF NOT EXISTS] <table_name>`` clause (matched
    case-insensitively, with or without backticks) has its table name
    replaced by the backtick-quoted backup name ``<table_name><suffix>``.

    Args:
        sql_statement: The SQL text to rewrite.
        table_name: The original table name, without backticks.
        suffix: Text appended to ``table_name`` to form the backup name.
            Defaults to ``_20250625_bak``.

    Returns:
        A ``(modified_sql, backup_table_name)`` tuple. ``modified_sql`` equals
        ``sql_statement`` when no matching CREATE TABLE clause is present.
    """
    backup_table_name = f"{table_name}{suffix}"

    # The trailing lookahead stops `foo` from matching the start of a longer
    # identifier such as `foo_bar`, which would otherwise be corrupted.
    pattern = (
        r'(CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?)`?'
        + re.escape(table_name)
        + r'`?(?![a-zA-Z0-9_])'
    )

    # A callable replacement inserts the backup name literally, so backslashes
    # or group references in `suffix` are never interpreted as a template.
    modified_sql = re.sub(
        pattern,
        lambda match: f"{match.group(1)}`{backup_table_name}`",
        sql_statement,
        flags=re.IGNORECASE,
    )
    return modified_sql, backup_table_name


def generate_backup_statement(original_table: str, backup_table: str) -> str:
    """Build the statement that copies all rows into the backup table.

    Args:
        original_table: Name of the table to copy from (without backticks).
        backup_table: Name of the table to copy into (without backticks).

    Returns:
        An ``INSERT INTO ... SELECT * FROM ...`` statement with both names
        backtick-quoted and a trailing semicolon.
    """
    target = f"`{backup_table}`"
    source = f"`{original_table}`"
    return f"INSERT INTO {target} SELECT * FROM {source};"


def generate_restore_statements(original_table: str, backup_table: str) -> Tuple[str, str]:
    """Build the two rename statements that swap the backup table back in.

    The first statement moves the current table out of the way under the
    name ``<original_table>_temp``; the second gives the backup table the
    original name. The ``_temp`` table is not dropped by these statements.

    NOTE(review): the emitted form is ``ALTER TABLE ... RENAME TO ...``;
    confirm the target database accepts the ``TO`` keyword.

    Args:
        original_table: Name of the live table (without backticks).
        backup_table: Name of the backup table (without backticks).

    Returns:
        A ``(rename_current, rename_backup)`` tuple, to be run in that order.
    """
    rename_template = "ALTER TABLE `{}` RENAME TO `{}`;"
    parked_name = f"{original_table}_temp"

    move_current_aside = rename_template.format(original_table, parked_name)
    promote_backup = rename_template.format(backup_table, original_table)
    return move_current_aside, promote_backup


def process_sql_file(file_path: str):
    """Generate backup DDL, data-backup, restore and report files for a SQL file.

    The input is read and split into statements, and the first line of every
    statement is discarded. For each statement in which a CREATE TABLE name
    is found, the DDL is rewritten to create the backup table, and one
    data-copy statement plus two restore (rename) statements are produced.
    Statements without a recognisable table name are passed through to the
    DDL output (minus their first line) with a warning.

    Four files are written next to the input file, named after its stem:
    ``<stem>_backup_ddl.sql``, ``<stem>_backup_data.sql``,
    ``<stem>_restore.sql`` and ``<stem>_report.txt``. Progress is printed to
    standard output.

    Args:
        file_path: Path of the SQL DDL file to process.

    Returns:
        None. Returns early, writing nothing, when the file content is empty
        or could not be read.
    """
    print(f"开始处理文件: {file_path}")

    content = read_sql_file(file_path)
    if not content:
        return

    statements = split_sql_statements(content)
    print(f"找到 {len(statements)} 个SQL语句")

    ddl_statements = []
    table_pairs = []  # (original table name, backup table name)

    for ordinal, raw_statement in enumerate(statements, start=1):
        print(f"处理第 {ordinal} 个语句...")

        body = remove_first_line(raw_statement)
        table_name = extract_table_name(body)

        if not table_name:
            # Not a recognisable CREATE TABLE: keep the statement as-is.
            print(f"  警告：无法从语句中提取表名")
            ddl_statements.append(body)
            continue

        print(f"  找到表名: {table_name}")
        backup_ddl, backup_table_name = add_backup_suffix(body, table_name)
        ddl_statements.append(backup_ddl)
        table_pairs.append((table_name, backup_table_name))

    # The data-copy and restore statements depend only on the name pairs,
    # so they are derived here rather than inside the loop.
    backup_lines = [
        generate_backup_statement(original, backup) + "\n"
        for original, backup in table_pairs
    ]
    restore_lines = [
        stmt + "\n"
        for original, backup in table_pairs
        for stmt in generate_restore_statements(original, backup)
    ]

    # All outputs live beside the input file and share its stem.
    out_dir, source_name = os.path.split(file_path)
    stem = os.path.splitext(source_name)[0]
    ddl_output_path = os.path.join(out_dir, f"{stem}_backup_ddl.sql")
    backup_output_path = os.path.join(out_dir, f"{stem}_backup_data.sql")
    restore_output_path = os.path.join(out_dir, f"{stem}_restore.sql")
    report_path = os.path.join(out_dir, f"{stem}_report.txt")

    # 1. Rewritten DDL, one statement per block separated by a blank line.
    with open(ddl_output_path, 'w', encoding='utf-8') as ddl_file:
        ddl_file.writelines([
            "-- 备份表DDL语句\n",
            "-- 生成时间: 2025-06-25\n\n",
        ])
        ddl_file.writelines(stmt + "\n\n" for stmt in ddl_statements)

    # 2. Data-copy statements, one per line.
    with open(backup_output_path, 'w', encoding='utf-8') as backup_file:
        backup_file.writelines([
            "-- 数据备份语句\n",
            "-- 生成时间: 2025-06-25\n\n",
        ])
        backup_file.writelines(backup_lines)

    # 3. Restore (rename) statements, one per line.
    with open(restore_output_path, 'w', encoding='utf-8') as restore_file:
        restore_file.writelines([
            "-- 数据还原语句\n",
            "-- 生成时间: 2025-06-25\n",
            "-- 注意：执行前请确保备份表存在\n\n",
        ])
        restore_file.writelines(restore_lines)

    # 4. Human-readable summary report.
    report_lines = [
        "SQL备份脚本处理报告\n",
        "=" * 50 + "\n",
        f"处理时间: 2025-06-25\n",
        f"源文件: {file_path}\n",
        f"处理的表数量: {len(table_pairs)}\n\n",
        "表名对应关系:\n",
        "-" * 30 + "\n",
    ]
    report_lines.extend(
        f"{original} -> {backup}\n" for original, backup in table_pairs
    )
    report_lines.extend([
        f"\n生成的文件:\n",
        f"1. DDL文件: {ddl_output_path}\n",
        f"2. 备份语句: {backup_output_path}\n",
        f"3. 还原语句: {restore_output_path}\n",
        f"4. 处理报告: {report_path}\n",
    ])
    with open(report_path, 'w', encoding='utf-8') as report_file:
        report_file.writelines(report_lines)

    print("\n处理完成！")
    print(f"生成的文件:")
    print(f"1. 备份表DDL: {ddl_output_path}")
    print(f"2. 数据备份语句: {backup_output_path}")
    print(f"3. 数据还原语句: {restore_output_path}")
    print(f"4. 处理报告: {report_path}")


if __name__ == "__main__":
    import sys

    # Built-in default input, used when no path is given on the command line.
    default_sql_file_path = r"C:\Users\56209\PycharmProjects\pythonProject\sqlscript\temp\all_tables_ddl_starrocks-第二批.sql"

    # An optional first command-line argument overrides the default, so the
    # script can be pointed at another file without editing the source.
    sql_file_path = sys.argv[1] if len(sys.argv) > 1 else default_sql_file_path

    # Process the file.
    process_sql_file(sql_file_path)
